#!/usr/bin/env python3
# coding: utf-8
# File: question_classifier.py
# Author: lhy<lhy_in_blcu@126.com,https://huangyong.github.io>
# Date: 18-10-4

import os
from py2neo import Graph
# import ahocorasick


class QuestionClassifier:
    def __init__(self):
        self.g = Graph(host="localhost",  # neo4j 搭载服务器的ip地址，ifconfig可获取到  #192.168.1.221
                       http_port=7474,  # neo4j 服务器监听的端口号
                       user="neo4j",  # 数据库user name，如果没有更改过，应该是neo4j
                       password="s3cr3t")
        # cur_dir = '/'.join(os.path.abspath(__file__).split('/')[:-1])
        # #　特征词路径
        # self.disease_path = os.path.join(cur_dir, 'dict/disease.txt')
        # self.department_path = os.path.join(cur_dir, 'dict/department.txt')
        # self.check_path = os.path.join(cur_dir, 'dict/check.txt')
        # self.drug_path = os.path.join(cur_dir, 'dict/drug.txt')
        # self.food_path = os.path.join(cur_dir, 'dict/food.txt')
        # self.producer_path = os.path.join(cur_dir, 'dict/producer.txt')
        # self.symptom_path = os.path.join(cur_dir, 'dict/symptom.txt')
        # self.deny_path = os.path.join(cur_dir, 'dict/deny.txt')
        # # 加载特征词
        # self.disease_wds= [i.strip() for i in open(self.disease_path, encoding='utf_8') if i.strip()]
        # self.department_wds= [i.strip() for i in open(self.department_path, encoding='utf_8') if i.strip()]
        # self.check_wds= [i.strip() for i in open(self.check_path, encoding='utf_8') if i.strip()]
        # self.drug_wds= [i.strip() for i in open(self.drug_path, encoding='utf_8') if i.strip()]
        # self.food_wds= [i.strip() for i in open(self.food_path, encoding='utf_8') if i.strip()]
        # self.producer_wds= [i.strip() for i in open(self.producer_path, encoding='utf_8') if i.strip()]
        # self.symptom_wds= [i.strip() for i in open(self.symptom_path, encoding='utf_8') if i.strip()]
        # self.region_words = set(self.department_wds + self.disease_wds + self.check_wds + self.drug_wds + self.food_wds + self.producer_wds + self.symptom_wds)
        # self.deny_words = [i.strip() for i in open(self.deny_path, encoding='utf_8') if i.strip()]
        # # 构造领域actree
        # self.region_tree = self.build_actree(list(self.region_words))
        # # 构建词典
        # self.wdtype_dict = self.build_wdtype_dict()
        # # 问句疑问词
        # self.self_type = ['什么是', '是什么']
        # self.symptom_qwds = ['症状', '表征', '现象', '症候', '表现']
        # self.cause_qwds = self._get_qwds('cause') #['原因','成因', '为什么', '怎么会', '怎样才', '咋样才', '怎样会', '如何会', '为啥', '为何', '如何才会', '怎么才会', '会导致', '会造成']
        # self.acompany_qwds = ['并发症', '并发', '一起发生', '一并发生', '一起出现', '一并出现', '一同发生', '一同出现', '伴随发生', '伴随', '共现']
        # self.food_qwds = ['饮食', '饮用', '吃', '食', '伙食', '膳食', '喝', '菜' ,'忌口', '补品', '保健品', '食谱', '菜谱', '食用', '食物','补品']
        # self.drug_qwds = ['药', '药品', '用药', '胶囊', '口服液', '炎片']
        # self.prevent_qwds = ['预防', '防范', '抵制', '抵御', '防止','躲避','逃避','避开','免得','逃开','避开','避掉','躲开','躲掉','绕开',
        #                      '怎样才能不', '怎么才能不', '咋样才能不','咋才能不', '如何才能不',
        #                      '怎样才不', '怎么才不', '咋样才不','咋才不', '如何才不',
        #                      '怎样才可以不', '怎么才可以不', '咋样才可以不', '咋才可以不', '如何可以不',
        #                      '怎样才可不', '怎么才可不', '咋样才可不', '咋才可不', '如何可不']
        # self.lasttime_qwds = ['周期', '多久', '多长时间', '多少时间', '几天', '几年', '多少天', '多少小时', '几个小时', '多少年']
        # self.cureway_qwds = ['怎么治疗', '如何医治', '怎么医治', '怎么治', '怎么医', '如何治', '医治方式', '疗法', '咋治', '怎么办', '咋办', '咋治']
        # self.cureprob_qwds = ['多大概率能治好', '多大几率能治好', '治好希望大么', '几率', '几成', '比例', '可能性', '能治', '可治', '可以治', '可以医']
        # self.easyget_qwds = ['易感人群', '容易感染', '易发人群', '什么人', '哪些人', '感染', '染上', '得上']
        # self.check_qwds = ['检查', '检查项目', '查出', '检查', '测出', '试出']
        # self.belong_qwds = ['属于什么科', '属于', '什么科', '科室']
        # self.cure_qwds = ['治疗什么', '治啥', '治疗啥', '医治啥', '治愈啥', '主治啥', '主治什么', '有什么用', '有何用', '用处', '用途',
        #                   '有什么好处', '有什么益处', '有何益处', '用来', '用来做啥', '用来作甚', '需要', '要']

        print('model init finished ......')

        return

    # 获取关键词
    def _get_qwds(self, label):
        return [i['n.name'] for i in self.g.run('match (n:{}) return n.name'.format(label))]

    '''分类主函数'''

    def classify(self, question,session):
        data = {}
        sql_name = 'match (n) where not n:question_type and  "{}" contains n.name return n.name as name,n.label as label'.format(
            question)
        print(sql_name)
        sql_label = 'match (m) where "{}" contains m.name return m.label as label'.format(question)
        data_name = self.g.run(sql_name).data()
        data_label = self.g.run(sql_label).data()
        dict_type = {i['name']: i['label'] for i in data_name}
        stop_wds = []
        for wd1 in dict_type.keys():
            for wd2 in dict_type.keys():
                if wd1 in wd2 and wd1 != wd2:
                    stop_wds.append(wd1)

        final_wds = [i for i in dict_type.keys() if i not in stop_wds]
        medical_dict = {i: dict_type.get(i) for i in final_wds}
        print("medical_dict:",medical_dict)
        print( "data_label:", data_label )
        print( "data_name:", data_name )
        if not medical_dict:
            # if not session['entity']:
            return {}
            # else:
            #     medical_dict = session['entity']
        session['entity'] = medical_dict
        if len(data_label)==0 or len(data_label) == len(data_name):
            return {}
        session['entity'] = medical_dict
        data['args'] = medical_dict
        print("data1111:",data)
        # 收集问句当中所涉及到的实体类型
        types = medical_dict.values()
        # for type_ in medical_dict.values():
        #     types += type_

        # question_type = 'others'
        #
        question_types = list(set([i['label'] for i in data_label]))

        # if len(data_label) == len(data_name):
        #     if 'symptom' in types:
        #         question_types = ['symptom_disease']
        #     if 'diseases' in types:
        #         question_types = ['disease_desc', 'disease_desc2', 'belong_disease', 'disease_acompany']
        # data['question_types'] = []

        if ('diseases' in types):
            if 'disease_drug' in question_types and 'disease_food' in question_types:
                question_types.remove('disease_food')
            elif 'food_deny' in question_types and 'disease_food' in question_types:
                question_types.remove('food_deny')
                question_types.remove('disease_food')
                question_types.append('disease_not_food')
            elif 'food_deny' not in question_types and 'disease_food' in question_types:
                question_types.remove('disease_food')
                question_types.append('disease_do_food')
        if ('symptom' in types):
            if 'disease_symptom' in question_types:
                question_types.remove('disease_symptom')
                question_types.append('symptom_disease')
            elif "disease_cause" in question_types:
                question_types.append('symptom_disease')
        if 'food' in types:
            if 'disease_food' in question_types and 'drug_disease' in question_types:
                if 'food_deny' in question_types:
                    question_types.remove('disease_food')
                    question_types.remove('drug_disease')
                    question_types.append('food_not_disease')
                else:
                    question_types.remove('disease_food')
                    question_types.remove('drug_disease')
                    question_types.append('food_do_disease')
        data['question_types'] = question_types
        #
        # # 症状
        # if self.check_words(self.symptom_qwds, question) and ('disease' in types):
        #     question_type = 'disease_symptom'
        #     question_types.append(question_type)
        #
        # if self.check_words(self.symptom_qwds, question) and ('symptom' in types):
        #     question_type = 'symptom_disease'
        #     question_types.append(question_type)
        #
        # # 原因
        # if self.check_words(self.cause_qwds, question) and ('disease' in types):
        #     question_type = 'disease_cause'
        #     question_types.append(question_type)
        # # 并发症
        # if self.check_words(self.acompany_qwds, question) and ('disease' in types):
        #     question_type = 'disease_acompany'
        #     question_types.append(question_type)
        #
        # # 推荐食品
        # if self.check_words(self.food_qwds, question) and 'disease' in types:
        #     deny_status = self.check_words(self.deny_words, question)
        #     if deny_status:
        #         question_type = 'disease_not_food'
        #     else:
        #         question_type = 'disease_do_food'
        #     question_types.append(question_type)
        #
        # #已知食物找疾病
        # if self.check_words(self.food_qwds+self.cure_qwds, question) and 'food' in types:
        #     deny_status = self.check_words(self.deny_words, question)
        #     if deny_status:
        #         question_type = 'food_not_disease'
        #     else:
        #         question_type = 'food_do_disease'
        #     question_types.append(question_type)
        #
        # # 推荐药品
        # if self.check_words(self.drug_qwds, question) and 'disease' in types:
        #     question_type = 'disease_drug'
        #     question_types.append(question_type)
        #
        # # 属于科室
        # if self.check_words(self.belong_qwds, question) and 'disease' in types:
        #     question_type = 'belong_disease'
        #     question_types.append(question_type)
        #
        # # 药品治啥病
        # if self.check_words(self.cure_qwds, question) and 'drug' in types:
        #     question_type = 'drug_disease'
        #     question_types.append(question_type)
        #
        # # 疾病接受检查项目
        # if self.check_words(self.check_qwds, question) and 'disease' in types:
        #     question_type = 'disease_check'
        #     question_types.append(question_type)
        #
        # # 已知检查项目查相应疾病
        # if self.check_words(self.check_qwds+self.cure_qwds, question) and 'check' in types:
        #     question_type = 'check_disease'
        #     question_types.append(question_type)
        #
        # #　症状防御
        # if self.check_words(self.prevent_qwds, question) and 'disease' in types:
        #     question_type = 'disease_prevent'
        #     question_types.append(question_type)
        #
        # # 疾病医疗周期
        # if self.check_words(self.lasttime_qwds, question) and 'disease' in types:
        #     question_type = 'disease_lasttime'
        #     question_types.append(question_type)
        #
        # # 疾病治疗方式
        # if self.check_words(self.cureway_qwds, question) and 'disease' in types:
        #     question_type = 'disease_cureway'
        #     question_types.append(question_type)
        #
        # # 疾病治愈可能性
        # if self.check_words(self.cureprob_qwds, question) and 'disease' in types:
        #     question_type = 'disease_cureprob'
        #     question_types.append(question_type)
        #
        # # 疾病易感染人群
        # if self.check_words(self.easyget_qwds, question) and 'disease' in types :
        #     question_type = 'disease_easyget'
        #     question_types.append(question_type)
        #
        # # 若没有查到相关的外部查询信息，那么则将该疾病的描述信息返回
        # if self.check_words(self.self_type, question)  and 'disease' in types:
        #     question_types = ['disease_desc','disease_desc2','belong_disease','disease_acompany']
        #
        #
        # # 若没有查到相关的外部查询信息，那么则将该疾病的描述信息返回
        # if self.check_words(self.self_type, question) and 'symptom' in types:
        #     question_types = ['symptom_disease']
        #
        #
        # # 将多个分类结果进行合并处理，组装成一个字典
        # data['question_types'] = question_types
        print("data:",data)
        return data

    '''构造词对应的类型'''

    # def build_wdtype_dict(self):
    #     wd_dict = dict()
    #     for wd in self.region_words:
    #         wd_dict[wd] = []
    #         if wd in self.disease_wds:
    #             wd_dict[wd].append('disease')
    #         if wd in self.department_wds:
    #             wd_dict[wd].append('department')
    #         if wd in self.check_wds:
    #             wd_dict[wd].append('check')
    #         if wd in self.drug_wds:
    #             wd_dict[wd].append('drug')
    #         if wd in self.food_wds:
    #             wd_dict[wd].append('food')
    #         if wd in self.symptom_wds:
    #             wd_dict[wd].append('symptom')
    #         if wd in self.producer_wds:
    #             wd_dict[wd].append('producer')
    #     return wd_dict
    #
    # '''构造actree，加速过滤'''
    #
    # def build_actree(self, wordlist):
    #     actree = ahocorasick.Automaton()
    #     for index, word in enumerate(wordlist):
    #         actree.add_word(word, (index, word))
    #     actree.make_automaton()
    #     return actree

    '''问句过滤'''

    # def check_medical(self, question):
    #     region_wds = []
    #     for i in self.region_tree.iter(question):
    #         wd = i[1][1]
    #         region_wds.append(wd)
    #     stop_wds = []
    #     for wd1 in region_wds:
    #         for wd2 in region_wds:
    #             if wd1 in wd2 and wd1 != wd2:
    #                 stop_wds.append(wd1)
    #     final_wds = [i for i in region_wds if i not in stop_wds]
    #     final_dict = {i: self.wdtype_dict.get(i) for i in final_wds}
    #
    #     return final_dict

    '''基于特征词进行分类'''

    # def check_words(self, wds, sent):
    #     for wd in wds:
    #         if wd in sent:
    #             return True
    #     return False
    #

if __name__ == '__main__':
    handler = QuestionClassifier()
    while 1:
        question = input('input an question:')
        data = handler.classify(question)
        print(data)
